package com.zh.dev.server;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;

import org.springframework.stereotype.Component;

import com.zh.dev.util.RsaUtil;

import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.symmetric.RC4;
import lombok.extern.slf4j.Slf4j;

@Component
@Slf4j
@ServerEndpoint("/ws")
public class WSTServer {
	
	
	private Session session;
	
	private AtomicInteger step = new AtomicInteger(0);
	
	private String k;
	private String host;
	private int port;
	
	private Socket socket;
	
	private OutputStream out;
	
	private RC4 rc;
	
	private InputStream input;
	
	//private Long prevTime;
	
	@OnOpen
	public void onOpen(Session session) {
		session.setMaxIdleTimeout(5*60*1000);
		this.session = session;
		this.step.set(0);
		log.info("连接建立");
	}
	
	@OnMessage
	public void strMsg(String msg) {
		
		switch(step.get()) {
			
		case 0:
			//密钥交换
			swKey(msg);
			break;
		case 1:
			//建立隧道
			buildLink(msg);
			break;
		}
		
		log.info("心跳包：" + this.session.getId());
	}
	
	@OnMessage 
	public void byteMsg(byte[] b) {
		if( this.step.get() == 2 ) {
			byte[] sendByte = this.rc.crypt(b);
			try {
				this.out.write(sendByte);
//				flushTime();
			} catch (IOException e) {
				e.printStackTrace();
				close();
			}
		}
	}
	
	@OnClose
	public void onClose() {
		log.info("通道关闭");
		close();
	}
	@OnError
	public void onError(Throwable ex) {
		close();
	}
	
	private void buildLink(String msg) {
		
		try {
			
			String url = this.rc.decrypt(msg);
			List<String> pms = StrUtil.split(url, ":");
			
			this.host = pms.get(0);
			this.port = Integer.valueOf(pms.get(1));
			
			this.socket = new Socket(host, this.port);
			
			this.out = socket.getOutputStream();
			
			new Thread(()->{
				
				log.info("开始传输数据");
				
				try{
					
					input = socket.getInputStream();
					
					while(socket.isConnected()) {
						
						byte[] b = new byte[1024];
						int len = 0;
						while( ( len = input.read(b) ) > 0) {
							byte[] send = new byte[len];
							System.arraycopy(b, 0, send, 0, len);
							this.session.getBasicRemote().sendBinary(ByteBuffer.wrap(this.rc.crypt(send)));
//							flushTime();
						}
						
					}
					
				}catch (IOException e) {
					close();
				}
				
				log.info("数据传输结束");
				
			}).start();
			
			sendMsg("success");
			this.step.addAndGet(1);
			
			log.info("隧道建立成功");
		}catch (Exception e) {
			e.printStackTrace();
			log.info("隧道构建失败");
			close();
		}
		
	}
	
	private void swKey(String msg) {
		try {
			//获取客户端发来的密钥
			this.k = RsaUtil.decodeHex(msg);
			this.rc = new RC4(this.k);
			this.step.addAndGet(1);
			sendMsg("success");
			log.info("密钥获取完毕");
		}catch (Exception e) {
			// 密钥获取失败，中断连接
			e.printStackTrace();
			log.info("密钥获取失败");
			close();
		}
	}
	
	private void sendMsg(String msg) {
		if(this.rc != null) {
			try {
				this.session.getBasicRemote().sendText(this.rc.encryptHex(msg));
			} catch (IOException e) {
				e.printStackTrace();
				close();
			}
		}
	}
	
	private void close() {
		try {
			
			if(input != null) {
				input.close();
			}
			
			if(out != null) {
				out.close();
			}
			
			if(this.socket != null) {
				this.socket.close();
			}
			
			session.close();
			
		} catch (IOException e) {}
	}
	
//	private void flushTime() {
//		//刷新定时
//		this.prevTime = DateUtil.current();
//	}

}
